"""Custom loss classes for probing tasks."""

import torch as th
import torch.nn as nn


class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""

    def __init__(self):
        super().__init__()
        # Dimensions summed over to obtain one loss value per sentence.
        self.word_pair_dims = (1, 2)

    def forward(self, predictions, label_batch, length_batch):
        """Compute the L1 loss on distance matrices.

        Ignores all entries where label_batch == -1.
        Normalizes first within sentences (by dividing by the square of
        the sentence length) and then across the batch (by dividing by the
        number of sentences whose length is non-zero).

        Sentences of length zero contribute nothing to the loss and are not
        counted in ``total_sents``; they are excluded explicitly so that the
        per-sentence normalization never divides by zero.

        Args:
            predictions: A pytorch batch of predicted distances
                (presumably shaped (batch, max_len, max_len) -- confirm
                against the caller).
            label_batch: A pytorch batch of true distances, with -1 marking
                entries to ignore; same shape as ``predictions``.
            length_batch: A pytorch batch of sentence lengths, one per
                sentence.

        Returns:
            A tuple of:
                batch_loss: average loss in the batch (a scalar tensor;
                    0.0 when the batch contains no non-empty sentence)
                total_sents: number of non-empty sentences in the batch,
                    as a float scalar tensor
        """
        # 1.0 where a gold label exists, 0.0 where the entry is padding (-1).
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s

        nonempty = length_batch != 0
        total_sents = th.sum(nonempty).float()
        if total_sents > 0:
            loss_per_sent = th.sum(
                th.abs(predictions_masked - labels_masked), dim=self.word_pair_dims
            )
            squared_lengths = length_batch.pow(2).float()
            # Empty sentences would produce 0/0 = NaN (or x/0 = inf) and
            # poison the batch sum, so divide them by 1 instead and zero
            # their contribution afterwards.
            safe_denominator = th.where(
                nonempty, squared_lengths, th.ones_like(squared_lengths)
            )
            normalized_loss_per_sent = loss_per_sent / safe_denominator
            normalized_loss_per_sent = th.where(
                nonempty,
                normalized_loss_per_sent,
                th.zeros_like(normalized_loss_per_sent),
            )
            batch_loss = th.sum(normalized_loss_per_sent) / total_sents
        else:
            # No non-empty sentence in the batch: report a constant zero loss.
            batch_loss = th.tensor(0.0, device=predictions.device)
        return batch_loss, total_sents


class L1DepthLoss(nn.Module):
    """Custom L1 loss for depth sequences."""

    def __init__(self):
        super().__init__()
        # Dimension summed over to obtain one loss value per sentence.
        self.word_dim = 1

    def forward(self, predictions, label_batch, length_batch):
        """Compute the L1 loss on depth sequences.

        Ignores all entries where label_batch == -1.
        Normalizes first within sentences (by dividing by the sentence
        length) and then across the batch (by dividing by the number of
        sentences whose length is non-zero).

        Sentences of length zero contribute nothing to the loss and are not
        counted in ``total_sents``; they are excluded explicitly so that the
        per-sentence normalization never divides by zero.

        Args:
            predictions: A pytorch batch of predicted depths
                (presumably shaped (batch, max_len) -- confirm against
                the caller).
            label_batch: A pytorch batch of true depths, with -1 marking
                entries to ignore; same shape as ``predictions``.
            length_batch: A pytorch batch of sentence lengths, one per
                sentence.

        Returns:
            A tuple of:
                batch_loss: average loss in the batch (a scalar tensor;
                    0.0 when the batch contains no non-empty sentence)
                total_sents: number of non-empty sentences in the batch,
                    as a float scalar tensor
        """
        nonempty = length_batch != 0
        total_sents = th.sum(nonempty).float()

        # 1.0 where a gold label exists, 0.0 where the entry is padding (-1).
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s

        if total_sents > 0:
            loss_per_sent = th.sum(
                th.abs(predictions_masked - labels_masked), dim=self.word_dim
            )
            lengths = length_batch.float()
            # Empty sentences would produce 0/0 = NaN (or x/0 = inf) and
            # poison the batch sum, so divide them by 1 instead and zero
            # their contribution afterwards.
            safe_denominator = th.where(nonempty, lengths, th.ones_like(lengths))
            normalized_loss_per_sent = loss_per_sent / safe_denominator
            normalized_loss_per_sent = th.where(
                nonempty,
                normalized_loss_per_sent,
                th.zeros_like(normalized_loss_per_sent),
            )
            batch_loss = th.sum(normalized_loss_per_sent) / total_sents
        else:
            # No non-empty sentence in the batch: report a constant zero loss.
            batch_loss = th.tensor(0.0, device=predictions.device)
        return batch_loss, total_sents
